{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## preliminary\n",
    "\n",
    "### import some library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "import seaborn as sns\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### read data into memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train = pd.read_csv(\"./trainDataset.csv\")\n",
    "df_test = pd.read_csv(\"./testDataset.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### explore data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived</th>\n",
       "      <th>PassengerClass</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SiblingSpouse</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>adult</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>adult</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>female</td>\n",
       "      <td>adult</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>adult</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>adult</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Survived  PassengerClass     Sex    Age  SiblingSpouse\n",
       "0         0               3    male  adult              1\n",
       "1         1               1  female  adult              1\n",
       "2         1               3  female  adult              0\n",
       "3         1               1  female  adult              1\n",
       "4         0               3    male  adult              0"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived</th>\n",
       "      <th>PassengerClass</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SiblingSpouse</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>female</td>\n",
       "      <td>adult</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>female</td>\n",
       "      <td>adult</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>adult</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>adult</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>female</td>\n",
       "      <td>adult</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Survived  PassengerClass     Sex    Age  SiblingSpouse\n",
       "0         0               3  female  adult              0\n",
       "1         1               3  female  adult              0\n",
       "2         1               1  female  adult              0\n",
       "3         1               1  female  adult              0\n",
       "4         1               2  female  adult              0"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3    110\n",
       "2     51\n",
       "1     39\n",
       "Name: PassengerClass, dtype: int64"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train[\"PassengerClass\"].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "adult      145\n",
       "teenage     31\n",
       "child       24\n",
       "Name: Age, dtype: int64"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train[\"Age\"].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "male      132\n",
       "female     68\n",
       "Name: Sex, dtype: int64"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train[\"Sex\"].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    129\n",
       "1     71\n",
       "Name: SiblingSpouse, dtype: int64"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train[\"SiblingSpouse\"].value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Baisic Hunt's Algorithm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### build and train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "hunt_rules = {}\n",
    "passenger_classes = [1,2,3]\n",
    "ages= [\"adult\", \"teenage\", \"child\"]\n",
    "sexes = [\"male\", \"female\"]\n",
    "sibling_spouses = [0, 1]\n",
    "for passenger_class in passenger_classes:\n",
    "    hunt_rules[passenger_class] = {}\n",
    "    for age in ages:\n",
    "        hunt_rules[passenger_class][age] = {}\n",
    "        for sex in sexes:\n",
    "            hunt_rules[passenger_class][age][sex] = {}\n",
    "            for sibling_spouse in sibling_spouses:\n",
    "                temp = df_train.loc[df_train.apply(lambda x: x[\"PassengerClass\"] == passenger_class and x[\"Age\"] == age and x[\"Sex\"] == sex and x[\"SiblingSpouse\"] == sibling_spouse, axis= 1), \"Survived\"]\n",
    "                num = len(temp)\n",
    "                if num == 0:\n",
    "                    label = -1\n",
    "                else:\n",
    "                    if temp.sum() * 1.0 / num >= 0.5:\n",
    "                        label = 1\n",
    "                    else:\n",
    "                        label = 0\n",
    "                hunt_rules[passenger_class][age][sex][sibling_spouse] = label"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### test on the test dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_test[\"huntLabel\"] = df_test.apply(lambda x: hunt_rules[x[\"PassengerClass\"]][x[\"Age\"]][x[\"Sex\"]][x[\"SiblingSpouse\"]], axis= 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gini Index with Greedy Strategy Tree"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### build and train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = DecisionTreeClassifier(criterion=\"gini\", splitter=\"best\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "age2num = {\"child\": 0, \"teenage\": 1, \"adult\": 2}\n",
    "sex2num = {\"male\": 0, \"female\": 1}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train = df_train.drop(\"Survived\", axis= 1).copy()\n",
    "x_train[\"Age\"] = x_train[\"Age\"].map(age2num)\n",
    "x_train[\"Sex\"] = x_train[\"Sex\"].map(sex2num)\n",
    "y_train = df_train.loc[:, \"Survived\"].values.reshape(-1,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n",
       "                       max_depth=None, max_features=None, max_leaf_nodes=None,\n",
       "                       min_impurity_decrease=0.0, min_impurity_split=None,\n",
       "                       min_samples_leaf=1, min_samples_split=2,\n",
       "                       min_weight_fraction_leaf=0.0, presort='deprecated',\n",
       "                       random_state=None, splitter='best')"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.fit(x_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### test on the test dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_test = df_test.loc[:, [\"PassengerClass\",\"Sex\",\"Age\",\"SiblingSpouse\"]].copy()\n",
    "x_test[\"Age\"] = x_test[\"Age\"].map(age2num)\n",
    "x_test[\"Sex\"] = x_test[\"Sex\"].map(sex2num)\n",
    "df_test[\"giniLabel\"] = clf.predict(x_test).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## compare the two algorithms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### the accuracy of Basic Hunt's Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7"
      ]
     },
     "execution_count": 101,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(df_test[\"huntLabel\"] == df_test[\"Survived\"]).sum() * 1.0 / df_test.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### the accuracy of Decision Tree Built with Gini Index and Greedy Strategy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.75"
      ]
     },
     "execution_count": 102,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(df_test[\"giniLabel\"] == df_test[\"Survived\"]).sum() * 1.0 / df_test.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### conclusion\n",
    "\n",
    "* The full growth decision tree by applying the Gini Index and greedy strategy is better than the tree built by the Basic Hunt's Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
